import os
import cv2
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Dataset root. NOTE: the name `dir` shadows the builtin, but it is kept
# because later cells build image paths from it.
dir = "/ext/Data/distracted_driver_detection/"
driver_imgs_list_csv = os.path.join(dir, "driver_imgs_list.csv")
# One row per training image; the code in this file reads the columns
# "subject", "classname" and "img".
df = pd.read_csv(driver_imgs_list_csv)
# Number of training images per driver: a Series indexed by subject id
# (groupby sorts the keys by default).
driver_list = df.groupby('subject')['img'].count()
print(driver_list)
print("drivers count = %d" % len(driver_list))
sns.countplot(y='subject', data=df, orient="h")
# seaborn >= 0.8 no longer re-exports pyplot as `sns.plt`; call matplotlib
# directly so this works on current seaborn releases.
plt.show()
# Number of training images per class label, indexed by "classname".
class_list = df.groupby('classname')['img'].count()
print(class_list)
print("classes count = %d" % len(class_list))
sns.countplot(x='classname', data=df)
# Use pyplot directly: `sns.plt` does not exist on seaborn >= 0.8.
plt.show()
# Per-driver image counts broken down by class. The tall figure keeps one bar
# group per subject readable. pyplot is called directly because `sns.plt`
# was removed from seaborn (>= 0.8).
plt.figure(figsize=(16, 32))
sns.countplot(y='subject', hue='classname', data=df)
plt.show()
def show_images(classname, ncols=5):
    """Plot the first training image of ``classname`` for every driver.

    Reads the module-level ``df`` (rows of driver_imgs_list.csv),
    ``driver_list`` (per-subject counts; its index is the subject ids) and
    ``dir`` (dataset root). Image paths are ``<dir>/train/<classname>/<img>``.

    Args:
        classname: a value of the ``classname`` column, e.g. ``"c0"``.
        ncols: number of subplot columns; the row count is derived from the
            number of drivers that have an image of this class.

    Raises:
        FileNotFoundError: if an image file cannot be read by OpenCV.
    """
    # Filter once and keep the first row per subject. This matches the
    # original per-driver `.head(1)` without rescanning df for every driver.
    firsts = (df[df["classname"] == classname]
              .drop_duplicates("subject")
              .set_index("subject"))
    samples = []
    for driver in driver_list.index:
        # A driver may have no image of this class; skip it instead of
        # raising IndexError on an empty selection.
        if driver not in firsts.index:
            continue
        samples.append(
            (driver, os.path.join(dir, "train", classname, firsts.at[driver, "img"])))
    if not samples:
        print("no images found for class %s" % classname)
        return
    nrows = -(-len(samples) // ncols)  # ceiling division
    plt.figure(figsize=(16, 16))
    for i, (driver, path) in enumerate(samples):
        plt.subplot(nrows, ncols, i + 1)
        img = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if img is None:
            raise FileNotFoundError(path)
        plt.title(driver)
        plt.axis('off')
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# Render one per-driver sample grid for each class label c0 ... c9.
for class_idx in range(10):
    show_images("c%d" % class_idx)
# Collect up to 20 consecutive training images of driver p002, starting at
# position `begin` among that driver's rows in df.
begin = 550
# Filter once, outside any loop. A single iloc slice also truncates, instead of
# raising IndexError, when p002 has fewer than begin + 20 rows.
p002_rows = df[df["subject"] == "p002"].iloc[begin:begin + 20]
images = [os.path.join(dir, "train", cls, name)
          for cls, name in zip(p002_rows["classname"], p002_rows["img"])]
print(len(images))
# Plot the collected image paths on a 4x5 grid (room for up to 20 images).
plt.figure(figsize=(16, 10))
for i, path in enumerate(images):
    plt.subplot(4, 5, i + 1)
    img = cv2.imread(path)
    # cv2.imread returns None (it does not raise) for unreadable files.
    if img is None:
        raise FileNotFoundError(path)
    plt.axis('off')
    # OpenCV loads BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# Show 30 test images starting at index `begin` of the test file list.
begin = 2000
# glob returns files in filesystem-dependent order; sort so the same 30
# images are selected on every run and machine.
images = sorted(glob.glob(os.path.join(dir, "test/test/", "*")))[begin:begin + 30]
plt.figure(figsize=(16, 16))
for i, path in enumerate(images):
    plt.subplot(6, 5, i + 1)
    img = cv2.imread(path)
    # cv2.imread returns None (it does not raise) for unreadable files.
    if img is None:
        raise FileNotFoundError(path)
    plt.axis('off')
    # OpenCV loads BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))